Predicting churn rates is very challenging, so many data scientists and analysts struggle with it in any customer-facing business. Since user-interaction services like Spotify need to communicate with customers frequently, large amounts of logging data are produced every day. Thus, in this project, I would like to show how to manipulate large, realistic datasets with Spark, as well as how to build a prediction model with Spark MLlib. Let's dive in!
Photo by Filip Havlik on Unsplash
First, you need to import lots of Spark libraries as below; then you can open a SparkSession instance to start wrangling the big data.
import os
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
from plotly import graph_objs as go
%matplotlib inline
Pretty printing has been turned OFF
# import pyspark libraries
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql import Window
from pyspark.sql.types import IntegerType
from pyspark.sql.types import FloatType
from pyspark.sql.types import StringType
from pyspark.sql.types import DateType
from pyspark.ml import Pipeline, PipelineModel, Estimator, Transformer
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.feature import VectorSlicer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import LinearSVC
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.tuning import CrossValidatorModel
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.mllib.evaluation import MulticlassMetrics
# Create (or reuse) a Spark session for this application.
spark = (
    SparkSession
    .builder
    .appName("Sparkify")
    .getOrCreate()
)
In this project, we are going to use the website logging data collected from a virtual music streaming company called "Sparkify". The original size of this dataset is 12GB, but we can start exploring the data with a subset of it — mini_sparkify_event_data.json.
# Load the event log into a Spark DataFrame.
df = spark.read.json("mini_sparkify_event_data.json")

# Print the basic information of the dataset: size, schema, and a sample row.
record_count = df.count()
print(f"Total records of the data: {record_count}")
df.printSchema()
df.head()
Total records of the data: 286500 root |-- artist: string (nullable = true) |-- auth: string (nullable = true) |-- firstName: string (nullable = true) |-- gender: string (nullable = true) |-- itemInSession: long (nullable = true) |-- lastName: string (nullable = true) |-- length: double (nullable = true) |-- level: string (nullable = true) |-- location: string (nullable = true) |-- method: string (nullable = true) |-- page: string (nullable = true) |-- registration: long (nullable = true) |-- sessionId: long (nullable = true) |-- song: string (nullable = true) |-- status: long (nullable = true) |-- ts: long (nullable = true) |-- userAgent: string (nullable = true) |-- userId: string (nullable = true)
Row(artist='Martha Tilston', auth='Logged In', firstName='Colin', gender='M', itemInSession=50, lastName='Freeman', length=277.89016, level='paid', location='Bakersfield, CA', method='PUT', page='NextSong', registration=1538173362000, sessionId=29, song='Rockpools', status=200, ts=1538352117000, userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0', userId='30')
# Print the unique categories of each column to understand the data structure.
# Note: .show() prints the table and returns None, so there is nothing useful
# to assign (the original bound its None result to an unused variable).
for col in ['auth', 'level', 'page', 'gender', 'location', 'artist', 'song']:
    df.select(col) \
        .groupBy(col).count() \
        .orderBy("count", ascending=False) \
        .show()
+----------+------+ | auth| count| +----------+------+ | Logged In|278102| |Logged Out| 8249| | Guest| 97| | Cancelled| 52| +----------+------+ +-----+------+ |level| count| +-----+------+ | paid|228162| | free| 58338| +-----+------+ +--------------------+------+ | page| count| +--------------------+------+ | NextSong|228108| | Home| 14457| | Thumbs Up| 12551| | Add to Playlist| 6526| | Add Friend| 4277| | Roll Advert| 3933| | Login| 3241| | Logout| 3226| | Thumbs Down| 2546| | Downgrade| 2055| | Help| 1726| | Settings| 1514| | About| 924| | Upgrade| 499| | Save Settings| 310| | Error| 258| | Submit Upgrade| 159| | Submit Downgrade| 63| | Cancel| 52| |Cancellation Conf...| 52| +--------------------+------+ only showing top 20 rows +------+------+ |gender| count| +------+------+ | F|154578| | M|123576| | null| 8346| +------+------+ +--------------------+-----+ | location|count| +--------------------+-----+ |Los Angeles-Long ...|30131| |New York-Newark-J...|23684| |Boston-Cambridge-...|13873| |Houston-The Woodl...| 9499| | null| 8346| |Charlotte-Concord...| 7780| |Dallas-Fort Worth...| 7605| |Louisville/Jeffer...| 6880| |Philadelphia-Camd...| 5890| |Chicago-Napervill...| 5114| | St. 
Louis, MO-IL| 4858| |Phoenix-Mesa-Scot...| 4846| |Vineland-Bridgeto...| 4825| | Wilson, NC| 4659| |Denver-Aurora-Lak...| 4453| | Ionia, MI| 4428| |San Antonio-New B...| 4373| | Danville, VA| 4257| |Atlanta-Sandy Spr...| 4236| |New Haven-Milford...| 4007| +--------------------+-----+ only showing top 20 rows +--------------------+-----+ | artist|count| +--------------------+-----+ | null|58392| | Kings Of Leon| 1841| | Coldplay| 1813| |Florence + The Ma...| 1236| | Dwight Yoakam| 1135| | Björk| 1133| | The Black Keys| 1125| | Muse| 1090| | Justin Bieber| 1044| | Jack Johnson| 1007| | Eminem| 953| | Radiohead| 884| | Alliance Ethnik| 876| | Train| 854| | Taylor Swift| 840| | OneRepublic| 828| | The Killers| 822| | Linkin Park| 787| | Evanescence| 781| | Harmonia| 729| +--------------------+-----+ only showing top 20 rows +--------------------+-----+ | song|count| +--------------------+-----+ | null|58392| | You're The One| 1153| | Undo| 1026| | Revelry| 854| | Sehr kosmisch| 728| |Horn Concerto No....| 641| |Dog Days Are Over...| 574| | Secrets| 466| | Use Somebody| 459| | Canada| 435| | Invalid| 424| | Ain't Misbehavin| 409| | Représente| 393| |Sincerité Et J...| 384| |Catch You Baby (S...| 373| | Yellow| 343| | Somebody To Love| 343| | Hey_ Soul Sister| 334| | The Gift| 327| | Fireflies| 312| +--------------------+-----+ only showing top 20 rows
Now it's time to clean up some empty and invalid data. If you run the code below, you can see that there are some users with an empty-string userId, who are probably not regular Sparkify users, so we can discard those users from our dataset. Also, you can see that there are a lot of None values in "song" and other columns, but we don't need to handle them for now (we are going to transform this dataset per user, so the song and artist data are not critical).
# Show the missing / invalid values in each column.
# The original ran two full-scan .count() actions per column (2 * 18 = 36
# passes over the data); a single .agg() computes every count in one pass
# and prints exactly the same output.
# f.count(f.when(cond, True)) counts only the rows where cond is true,
# matching df.where(cond).count().
agg_exprs = []
for col in df.columns:
    agg_exprs.append(f.count(f.when(df[col] == "", True)).alias(f"{col}_empty"))
    agg_exprs.append(f.count(f.when(df[col].isNull(), True)).alias(f"{col}_none"))
counts = df.agg(*agg_exprs).collect()[0]
for col in df.columns:
    print(f"{col}: \t\tempty({counts[col + '_empty']}), \tnone({counts[col + '_none']})")
artist: empty(0), none(58392) auth: empty(0), none(0) firstName: empty(0), none(8346) gender: empty(0), none(8346) itemInSession: empty(0), none(0) lastName: empty(0), none(8346) length: empty(0), none(58392) level: empty(0), none(0) location: empty(0), none(8346) method: empty(0), none(0) page: empty(0), none(0) registration: empty(0), none(8346) sessionId: empty(0), none(0) song: empty(0), none(58392) status: empty(0), none(0) ts: empty(0), none(0) userAgent: empty(0), none(8346) userId: empty(8346), none(0)
# Visualize the missing values for each column as a black/white heatmap
# (1 = missing, 0 = present).
pd_isnull = df.toPandas().isnull().replace({True: 1, False: 0})

# One heatmap trace: columns on the x-axis, row index on the y-axis.
heatmap = go.Heatmap(
    x=pd_isnull.columns.tolist(),
    y=pd_isnull.index.tolist(),
    z=pd_isnull.values.tolist(),
    xgap=0.5,
    colorscale=[[0, 'black'], [1, 'whitesmoke']],
    showscale=False,
)

# Centered title above the plot area.
title_cfg = dict(
    text='Missing Data HeatMap',
    x=0.5,
    y=0.9,
    xanchor='center',
    yanchor='top',
    font_size=25,
)
layout = dict(
    title=title_cfg,
    plot_bgcolor='darkgrey',
    paper_bgcolor='rgb(243,243,243)',
    font=dict(family='Times New Roman', size=15),
    xaxis=dict(title='columns', ticks='outside', tickangle=-45, side='top'),
    yaxis=dict(title='index', showticklabels=False),
    margin=dict(t=200, b=10),
)

fig = go.Figure(data=[heatmap], layout=layout)
fig.show()